# runners/gp_runner.py
from __future__ import annotations
import time
import optuna
import ConfigSpace as CS

from objective import Objective
from loggers import ExperimentLogger

from runners.random_runner import _suggest_from_configspace


def _make_gp_sampler(*,
                     seed: int,
                     n_startup_trials: int,
                     deterministic_objective: bool) -> optuna.samplers.BaseSampler:
    """Build an Optuna GPSampler, tolerating versions without ``deterministic_objective``.

    The keyword is tried first. If the installed GPSampler's constructor
    rejects it (TypeError), the sampler is rebuilt without it.
    """
    base_kwargs = dict(seed=seed, n_startup_trials=n_startup_trials)
    try:
        return optuna.samplers.GPSampler(
            **base_kwargs, deterministic_objective=deterministic_objective)
    except TypeError:
        # An unsupported keyword is reported by the constructor call itself
        # (not by building a kwargs dict), so the fallback must wrap the call.
        return optuna.samplers.GPSampler(**base_kwargs)


def run_gp_optuna(*,
                  seed: int,
                  bench: str,
                  cs: CS.ConfigurationSpace,
                  obj: Objective,
                  budget_n: int,
                  logger: ExperimentLogger,
                  method_name: str = "GPBO",
                  n_startup_trials: int | None = None,
                  deterministic_objective: bool = True) -> optuna.trial.FrozenTrial:
    """
    Gaussian Process Bayesian Optimization using Optuna GPSampler.

    - Internal optimization objective: minimize(loss)
    - Logging: best_score = 1 - best_loss, curr_score = 1 - curr_loss
    - Hyperparameter suggestion: reuse random_runner._suggest_from_configspace
      (handles canonicalization and conditions)

    Args:
        seed: Seed passed to GPSampler.
        bench: Benchmark name written to every log record.
        cs: Search space; each trial's config is drawn from it via
            ``_suggest_from_configspace``.
        obj: Objective whose ``evaluate(cfg)`` returns ``(loss, sim_time)``.
        budget_n: Number of trials to run; must be >= 1.
        logger: Receives one record per trial that produced a value.
        method_name: Method label written to every log record.
        n_startup_trials: Passed to GPSampler as its number of startup trials.
            Defaults to ``max(10, budget_n // 10)``.
        deterministic_objective: Forwarded to GPSampler when the installed
            Optuna accepts it; otherwise silently dropped.

    Returns:
        The study's best (lowest-loss) trial.

    Raises:
        ValueError: If ``budget_n < 1``. Optuna's ``study.best_trial`` also
            raises ValueError if no trial produced a value.
    """
    if budget_n < 1:
        raise ValueError(f"budget_n must be >= 1, got {budget_n}")

    # Default startup: 10% of budget (at least 10). NOTE(review): when
    # budget_n <= 10, every trial is presumably a startup trial, so the GP
    # model is never used — confirm this is acceptable for tiny budgets.
    if n_startup_trials is None:
        n_startup_trials = max(10, budget_n // 10)

    study = optuna.create_study(
        direction="minimize",
        sampler=_make_gp_sampler(
            seed=seed,
            n_startup_trials=n_startup_trials,
            deterministic_objective=deterministic_objective,
        ),
        pruner=optuna.pruners.NopPruner(),
    )

    best = float("inf")
    # Running count of trials that produced a value. This replaces rescanning
    # (and deep-copying) study.trials in every callback, which was O(n^2).
    # It is valid because callbacks run sequentially (n_jobs defaults to 1)
    # on a freshly created study.
    n_valid = 0

    def objective(trial: optuna.Trial):
        # Reuse the shared suggestion function (Constant / NormalFloat /
        # conditional dependencies / canonicalization).
        cfg = _suggest_from_configspace(trial, cs)

        t0 = time.perf_counter()
        curr, sim_t = obj.evaluate(cfg)      # curr = loss
        elapsed = time.perf_counter() - t0

        trial.set_user_attr("sim_time", sim_t)
        trial.set_user_attr("elapsed_time", elapsed)
        trial.set_user_attr("config", cfg)
        return curr

    def cb(study: optuna.Study, trial: optuna.FrozenTrial):
        nonlocal best, n_valid
        # Trials without a value (e.g. failed ones) are neither counted nor logged.
        if trial.value is None:
            return
        n_valid += 1
        curr = trial.value
        best = min(best, curr)
        logger.log(dict(
            seed=seed, method=method_name, bench=bench,
            n_eval=n_valid,
            sim_time=trial.user_attrs.get("sim_time", 0.0),
            elapsed_time=trial.user_attrs.get("elapsed_time", 0.0),
            best_score=1 - best,     # Consistent with other runners
            curr_score=1 - curr,
            config=trial.user_attrs.get("config", {}),
        ))

    study.optimize(objective, n_trials=budget_n, callbacks=[cb], show_progress_bar=False)
    return study.best_trial

